import torch
import torchvision
from torchvision.models import ResNet50_Weights
# Load a ResNet50 model initialised with the torchvision IMAGENET1K_V2
# pretrained weights (downloaded on first use if not already cached).
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_classes = 2
# Replace the final fully connected layer (model.fc) with a freshly
# initialised Linear layer that keeps the same input width but outputs
# num_classes values.
in_features = model.fc.in_features
model.fc = torch.nn.Linear(in_features, num_classes)
# Use the GPU when one is available, otherwise fall back to the CPU.
# A hard-coded "cuda" makes model.to() raise on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
# nn.Module.to moves parameters in place; the return value is not needed.
model.to(device)
# Training hyperparameters (presumably consumed by a training loop
# elsewhere in the file — TODO confirm).
num_epochs = 20
lr = 1e-5
batch_size = 16
# CrossEntropyLoss applies log-softmax internally, so the model's raw
# outputs are passed to it directly.
criterion = torch.nn.CrossEntropyLoss()
# All model parameters are given to the optimizer; nothing visible here
# freezes the backbone, so it is fine-tuned together with the new head.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
